#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on 2022/03/09

@author: Feng Sheng
"""

# -*- coding: utf-8 -*-
"""
Spyder Editor

"""

import os
import pickle

import hddm
import pandas as pd
import matplotlib.pyplot as plt
from kabuki.analyze import *
import pylab
import numpy
# Explicit alias, placed after the star import so `np` is always numpy rather
# than whatever kabuki.analyze happens to re-export.
import numpy as np

# Directory the script was launched from; each model section below chdirs
# into a sub-folder of it, so later paths are built from this root.
root_path = os.getcwd()

# Trial-level data shared by every model fit and by the RT simulation.
trial_csv_path = './trial_hddm.csv'
data = hddm.load_csv(trial_csv_path)

#%% DDM4a - role effect on lottery weighting, z, a, intercept

# Outputs of this model go to <root_path>/DDM4a. os.path.join replaces the
# hard-coded Windows '\\' separator (which breaks on POSIX), and makedirs
# creates the folder on a fresh checkout instead of failing inside chdir.
DDM4a_dir = os.path.join(root_path, 'DDM4a')
os.makedirs(DDM4a_dir, exist_ok=True)
os.chdir(DDM4a_dir)

# Full model: drift has a Role shift, a Price slope and a separate Lottery
# slope per Role; z and a each get a C(Role) regressor.
DDM4a = hddm.HDDMRegressor(
    data,
    ["v ~ C(Role) + Price + Lottery:C(Role)", "z ~ C(Role)", "a ~ C(Role)"],
    group_only_regressors=False,
    include={'a', 'v', 'z'},
)
DDM4a.find_starting_values()
# 12000 MCMC draws with burn=2000; traces stored in a pickle database.
DDM4a.sample(12000, burn=2000, dbname='traces.db', db='pickle')

DDM4a.print_stats()
DDM4a.plot_posteriors()
DDM4a.save('DDM4a')

# Raw posterior traces and the per-node summary table.
DDM4a_trace = DDM4a.get_traces()
DDM4a_trace.to_csv("./DDM4a_trace.csv")

DDM4a_post = DDM4a.nodes_db
DDM4a_post.to_csv("./DDM4a_post.csv")

# Reload the model just saved (cwd is the model folder) and export its stats.
DDM4a = hddm.load('DDM4a')
stats = DDM4a.gen_stats()
stats.to_csv("DDM4a_stats.csv")

# Subject-level summary of DDM4a: one row per subject (row = subj_idx - 1).
# Role 1 is labelled "Buy" and Role 2 "Sell"; the *Delta columns hold the
# C(Role)[T.2] offset, so each *Sell value is *Buy + *Delta.
post_tab_columns = [
    'aBuy', 'aSell', 'aDelta', 't', 'zBuy', 'zSell', 'zDelta', 'vPrice',
    'vLotteryBuy', 'vLotterySell', 'wLotteryBuy', 'wLotterySell',
    'vInterceptBuy', 'vInterceptSell', 'vInterceptDelta', 'ValBias', 'ResBias',
]

# Columns read straight from DDM4a_post, mapped to the HDDM node-name prefix
# to which the subject id is appended.
post_tab_nodes = {
    'aBuy': 'a_Intercept_subj.',
    'aDelta': 'a_C(Role)[T.2]_subj.',
    't': 't_subj.',
    'zBuy': 'z_Intercept_subj.',
    'zDelta': 'z_C(Role)[T.2]_subj.',
    'vPrice': 'v_Price_subj.',
    'vLotteryBuy': 'v_Lottery:C(Role)[1]_subj.',
    'vLotterySell': 'v_Lottery:C(Role)[2]_subj.',
    'vInterceptBuy': 'v_Intercept_subj.',
    'vInterceptDelta': 'v_C(Role)[T.2]_subj.',
}

# Size the table from the data instead of a hard-coded 64 rows (extra
# subjects were silently dropped before). Assumes subj_idx is 1-based, as the
# row = isub - 1 indexing always did.
n_subj = int(np.max(data.subj_idx))
post_tab = pd.DataFrame(0.0, index=range(n_subj), columns=post_tab_columns)

# .loc[row, col] writes into post_tab itself. The previous chained form
# (post_tab.aBuy[isub-1] = ...) may write into a temporary Series instead
# (SettingWithCopy / Copy-on-Write) and leave the table all zeros.
for isub in np.unique(data.subj_idx):
    for col, node in post_tab_nodes.items():
        post_tab.loc[isub - 1, col] = DDM4a_post.loc[node + str(isub), 'mean']

# Derived per-role quantities.
post_tab['aSell'] = post_tab['aBuy'] + post_tab['aDelta']
post_tab['zSell'] = post_tab['zBuy'] + post_tab['zDelta']
post_tab['vInterceptSell'] = post_tab['vInterceptBuy'] + post_tab['vInterceptDelta']
# Lottery drift weight expressed relative to the negated Price weight.
post_tab['wLotteryBuy'] = post_tab['vLotteryBuy'] / -post_tab['vPrice']
post_tab['wLotterySell'] = post_tab['vLotterySell'] / -post_tab['vPrice']
# ValBias: ratio of Sell to Buy relative lottery weight.
post_tab['ValBias'] = post_tab['wLotterySell'] / post_tab['wLotteryBuy']
# ResBias: Sell minus Buy starting point (equal to zDelta).
post_tab['ResBias'] = post_tab['zSell'] - post_tab['zBuy']

post_tab.to_csv("./DDM4a_post_subject.csv")

# Store DDM4a's DIC as a one-row CSV, alongside the DIC files written for the
# reduced models further down.
ddm_dic = pd.DataFrame({'dic': [DDM4a.dic]})
ddm_dic.to_csv("./DDM4a_dic.csv")


# simulate choice and rt
#
# For each trial, rebuild DDM4a's parameters from the subject-level posterior
# means and record the signed RT at which the Wiener first-passage density
# peaks (its mode). The sign of the grid value separates the two response
# boundaries (presumably negative = lower boundary; confirm against the RT
# coding in trial_hddm.csv).

os.chdir(os.path.join(root_path, 'DDM4a'))
DDM4a = hddm.load('DDM4a')

# Numerical tolerance passed to hddm.wfpt.pdf_array.
err = 0.0001

post = DDM4a.nodes_db

# Signed RT grid from -10 to +10 in steps of 0.001 (same units as the data's
# rt column). It does not depend on the trial, so build it once, outside the loop.
rt = np.arange(-10.000, 10.001, 0.001)


def _subj_mean(node, isub):
    """Return the posterior mean of subject-level node '<node>_subj.<isub>' from `post`."""
    return post.loc[node + '_subj.' + str(isub), 'mean']


rt_sim = np.zeros(data.shape[0])

for idx in range(data.shape[0]):

    isub = data.subj_idx[idx]
    irole = data.Role[idx]
    # Role is treated as 1/2: role2 switches the C(Role)[T.2] offsets on for
    # Role 2, and (2 - irole) / role2 select the role-specific lottery slope.
    role2 = irole - 1

    t = _subj_mean('t', isub)
    a = _subj_mean('a_Intercept', isub) + role2 * _subj_mean('a_C(Role)[T.2]', isub)
    # NOTE(review): z is combined on the identity scale. If the installed HDDM
    # applies a link (e.g. inverse logit) to z regressions, transform it the
    # same way here; confirm.
    z = _subj_mean('z_Intercept', isub) + role2 * _subj_mean('z_C(Role)[T.2]', isub)

    v_Intercept = _subj_mean('v_Intercept', isub) + role2 * _subj_mean('v_C(Role)[T.2]', isub)
    v_price_term = _subj_mean('v_Price', isub) * data.Price[idx]
    v_Buy = _subj_mean('v_Lottery:C(Role)[1]', isub) * data.Lottery[idx] + v_price_term
    v_Sell = _subj_mean('v_Lottery:C(Role)[2]', isub) * data.Lottery[idx] + v_price_term
    v = v_Intercept + (2 - irole) * v_Buy + role2 * v_Sell

    # Positional args presumably follow pdf_array(x, v, sv, a, z, sz, t, st, err);
    # sv, sz and st are passed as 0. Verify against the installed hddm version.
    pdf = hddm.wfpt.pdf_array(rt, v, 0, a, z, 0, t, 0, err)
    rt_sim[idx] = rt[np.argmax(pdf)]

# Assign the whole column at once. The previous per-element chained form
# (data.RTSim[idx] = ...) is not guaranteed to write through to `data`
# (SettingWithCopy / Copy-on-Write) and could leave RTSim all zeros.
data['RTSim'] = rt_sim
data.to_csv(os.path.join(root_path, 'DDM4a', 'DDM4a_hddm_rt_simulated.csv'))

#%% Test reduced models relative to DDM4a

# DDM2a - No role effect on lottery weighting
# (single Lottery slope shared by both roles). os.path.join replaces the
# Windows-only '\\' separator; makedirs avoids failing on a missing folder.
DDM2a_dir = os.path.join(root_path, 'DDM2a')
os.makedirs(DDM2a_dir, exist_ok=True)
os.chdir(DDM2a_dir)

DDM2a = hddm.HDDMRegressor(
    data,
    ["v ~ C(Role) + Price + Lottery", "z ~ C(Role)", "a ~ C(Role)"],
    group_only_regressors=False,
    include={'a', 'v', 'z'},
)
DDM2a.find_starting_values()
DDM2a.sample(12000, burn=2000, dbname='traces.db', db='pickle')

DDM2a.print_stats()
DDM2a.plot_posteriors()
DDM2a.save('DDM2a')

DDM2a_trace = DDM2a.get_traces()
DDM2a_trace.to_csv("./DDM2a_trace.csv")

DDM2a_post = DDM2a.nodes_db
DDM2a_post.to_csv("./DDM2a_post.csv")

# Reload the saved model and export summary statistics and DIC.
DDM2a = hddm.load('DDM2a')
stats = DDM2a.gen_stats()
stats.to_csv("DDM2a_stats.csv")

ddm_dic = pd.DataFrame({'dic': [DDM2a.dic]})
ddm_dic.to_csv("./DDM2a_dic.csv")


# DDM3a - no role effect on z
# (z is still estimated via include, but has no Role regressor).
# os.path.join replaces the Windows-only '\\' separator; makedirs avoids
# failing on a missing folder.
DDM3a_dir = os.path.join(root_path, 'DDM3a')
os.makedirs(DDM3a_dir, exist_ok=True)
os.chdir(DDM3a_dir)

DDM3a = hddm.HDDMRegressor(
    data,
    ["v ~ C(Role) + Price + Lottery:C(Role)", "a ~ C(Role)"],
    group_only_regressors=False,
    include={'a', 'v', 'z'},
)
DDM3a.find_starting_values()
DDM3a.sample(12000, burn=2000, dbname='traces.db', db='pickle')

DDM3a.print_stats()
DDM3a.plot_posteriors()
DDM3a.save('DDM3a')

DDM3a_trace = DDM3a.get_traces()
DDM3a_trace.to_csv("./DDM3a_trace.csv")

DDM3a_post = DDM3a.nodes_db
DDM3a_post.to_csv("./DDM3a_post.csv")

# Reload the saved model and export summary statistics and DIC.
DDM3a = hddm.load('DDM3a')
stats = DDM3a.gen_stats()
stats.to_csv("DDM3a_stats.csv")

ddm_dic = pd.DataFrame({'dic': [DDM3a.dic]})
ddm_dic.to_csv("./DDM3a_dic.csv")


# DDM4 - no role effect on intercept
# (drift keeps the Price and per-Role Lottery slopes, but has no C(Role) term).
# os.path.join replaces the Windows-only '\\' separator; makedirs avoids
# failing on a missing folder.
DDM4_dir = os.path.join(root_path, 'DDM4')
os.makedirs(DDM4_dir, exist_ok=True)
os.chdir(DDM4_dir)

DDM4 = hddm.HDDMRegressor(
    data,
    ["v ~ Price + Lottery:C(Role)", "z ~ C(Role)", "a ~ C(Role)"],
    group_only_regressors=False,
    include={'a', 'v', 'z'},
)
DDM4.find_starting_values()
DDM4.sample(12000, burn=2000, dbname='traces.db', db='pickle')

DDM4.print_stats()
DDM4.plot_posteriors()
DDM4.save('DDM4')

DDM4_trace = DDM4.get_traces()
DDM4_trace.to_csv("./DDM4_trace.csv")

DDM4_post = DDM4.nodes_db
DDM4_post.to_csv("./DDM4_post.csv")

# Reload the saved model and export summary statistics and DIC.
DDM4 = hddm.load('DDM4')
stats = DDM4.gen_stats()
stats.to_csv("DDM4_stats.csv")

ddm_dic = pd.DataFrame({'dic': [DDM4.dic]})
ddm_dic.to_csv("./DDM4_dic.csv")


# DDM0a - no role effect on a
# (a has no Role regressor). os.path.join replaces the Windows-only '\\'
# separator; makedirs avoids failing on a missing folder.
DDM0a_dir = os.path.join(root_path, 'DDM0a')
os.makedirs(DDM0a_dir, exist_ok=True)
os.chdir(DDM0a_dir)

DDM0a = hddm.HDDMRegressor(
    data,
    ["v ~ C(Role) + Price + Lottery:C(Role)", "z ~ C(Role)"],
    group_only_regressors=False,
    include={'a', 'v', 'z'},
)
DDM0a.find_starting_values()
DDM0a.sample(12000, burn=2000, dbname='traces.db', db='pickle')

DDM0a.print_stats()
DDM0a.plot_posteriors()
DDM0a.save('DDM0a')

DDM0a_trace = DDM0a.get_traces()
DDM0a_trace.to_csv("./DDM0a_trace.csv")

DDM0a_post = DDM0a.nodes_db
DDM0a_post.to_csv("./DDM0a_post.csv")

# Reload the saved model and export summary statistics and DIC.
DDM0a = hddm.load('DDM0a')
stats = DDM0a.gen_stats()
stats.to_csv("DDM0a_stats.csv")

ddm_dic = pd.DataFrame({'dic': [DDM0a.dic]})
ddm_dic.to_csv("./DDM0a_dic.csv")


#%% DDM4b - no intercept

# The patsy term "0 +" removes the drift intercept. os.path.join replaces the
# Windows-only '\\' separator; makedirs avoids failing on a missing folder.
DDM4b_dir = os.path.join(root_path, 'DDM4b')
os.makedirs(DDM4b_dir, exist_ok=True)
os.chdir(DDM4b_dir)

DDM4b = hddm.HDDMRegressor(
    data,
    ["v ~ 0 + Price + Lottery:C(Role)", "z ~ C(Role)", "a ~ C(Role)"],
    group_only_regressors=False,
    include={'a', 'v', 'z'},
)
DDM4b.find_starting_values()
DDM4b.sample(12000, burn=2000, dbname='traces.db', db='pickle')

DDM4b.print_stats()
DDM4b.plot_posteriors()
DDM4b.save('DDM4b')

DDM4b_trace = DDM4b.get_traces()
DDM4b_trace.to_csv("./DDM4b_trace.csv")

DDM4b_post = DDM4b.nodes_db
DDM4b_post.to_csv("./DDM4b_post.csv")

# Reload the saved model and export summary statistics and DIC.
DDM4b = hddm.load('DDM4b')
stats = DDM4b.gen_stats()
stats.to_csv("DDM4b_stats.csv")

ddm_dic = pd.DataFrame({'dic': [DDM4b.dic]})
ddm_dic.to_csv("./DDM4b_dic.csv")

#%%
